from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from .conjugation import Conjugation

def pair(x):
    """Return ``x`` unchanged if it is a tuple, otherwise the 2-tuple ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)


def PPFC(image_size, channels, patch_size, dim):
    """Build the per-patch fully-connected (PPFC) stem.

    The returned pipeline cuts a ``(b, c, H, W)`` input into non-overlapping
    ``patch_size x patch_size`` patches, flattens every patch to
    ``patch_size ** 2 * channels`` values, projects it to ``dim`` features
    with one shared ``nn.Linear``, and finally merges the patch axis and the
    feature axis into a single flat axis.

    Args:
        image_size: Image side length, or a ``(height, width)`` tuple.
        channels: Number of input image channels.
        patch_size: Side length of each square patch.
        dim: Output feature size of the per-patch linear layer.

    Returns:
        nn.Sequential: ``Rearrange -> Linear -> Rearrange``.

    Raises:
        AssertionError: If either image side is not a multiple of
            ``patch_size``.
    """
    image_h, image_w = pair(image_size)
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    patch_features = (patch_size ** 2) * channels

    return nn.Sequential(
        # (b, c, H, W) -> (b, number of patches, values per patch)
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
        nn.Linear(patch_features, dim),
        # Merge the patch and feature axes: (b, s, c) -> (b, s * c)
        Rearrange("b s c -> b (s c)"),
    )





class PreNormResidual(nn.Module):
    """Residual wrapper that layer-normalizes its input before applying ``fn``.

    The forward pass computes ``fn(LayerNorm(x)) + x``, so ``fn`` has to
    return something that can be added to ``x``.

    Args:
        dim: Size of the last axis, passed to ``nn.LayerNorm``.
        fn: Callable/module applied to the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normalized = self.norm(x)
        branch = self.fn(normalized)
        return branch + x

class PreNormResidualTokenDiff(nn.Module):
    """Pre-norm residual block whose skip path can change the token count.

    The forward pass computes ``fn(LayerNorm(x)) + shortcut(x)``. When the
    token counts agree and ``force_fc_in_skip`` is 0, the shortcut is an
    empty ``nn.Sequential`` (identity). Otherwise the shortcut transposes
    ``(n, s, c)`` to ``(n, c, s)``, applies a bias-free linear map
    ``i_dim_token -> o_dim_token`` over the token axis, transposes back and
    applies a ``LayerNorm`` over the last axis.

    Args:
        dim: Size of the last axis, used by both LayerNorm layers.
        fn: Module applied to the normalized input; its output must be
            addable to the shortcut output.
        i_dim_token: Number of tokens (axis 1) in the input.
        o_dim_token: Number of tokens the shortcut produces.
        force_fc_in_skip: Any value other than 0 forces the linear shortcut
            even when ``i_dim_token == o_dim_token``.
    """

    def __init__(self, dim, fn, i_dim_token, o_dim_token, force_fc_in_skip=0):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if o_dim_token == i_dim_token and force_fc_in_skip == 0:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                Rearrange("n s c -> n c s"),
                # bias must be a bool; False states the bias-free intent explicitly.
                nn.Linear(in_features=i_dim_token, out_features=o_dim_token, bias=False),
                Rearrange("n c s -> n s c"),
                nn.LayerNorm(dim),
            )

    def forward(self, x):
        return self.fn(self.norm(x)) + self.shortcut(x)





def FeedForward(i_dim, o_dim, expansion_factor=4):
    """Build the MLP used inside a mixing block.

    Args:
        i_dim: Input feature size.
        o_dim: Output feature size.
        expansion_factor: When positive, a hidden layer of width
            ``int(o_dim * expansion_factor)`` is used
            (``Linear -> GELU -> Linear``). Otherwise (e.g. ``-1``) the hidden
            layer is omitted and the block is ``Linear -> GELU``.

    Returns:
        nn.Sequential: The assembled layers.
    """
    if expansion_factor > 0:
        hidden = int(o_dim * expansion_factor)
        layers = (
            nn.Linear(i_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, o_dim),
        )
    else:
        # No hidden layer: a single projection followed by the activation.
        layers = (
            nn.Linear(i_dim, o_dim),
            nn.GELU(),
        )
    return nn.Sequential(*layers)






def MLPMixer(*, image_size, channels, patch_size, dim, dim_token, depth, num_classes, expansion_factor=4, expansion_factor_token=4, permute_per_blocks=0, remove_ppfc=0,
             force_fc_in_skip=0, perm_block_id=-1):
    """Build an MLP-Mixer (or, with non-positive expansion factors, an S-Mixer).

    All arguments are keyword-only.

    Args:
        image_size: Input image side length, or a ``(height, width)`` tuple.
        channels: Number of input image channels.
        patch_size: Patch side length P.
        dim: Channel width C.
        dim_token: Token width S used from the first block onwards; the first
            token-mixing layer maps the number of patches to ``dim_token``.
        depth: Number of base blocks L; each block is a token-mixing layer
            followed by a channel-mixing layer.
        num_classes: Number of classification labels.
        expansion_factor: Hidden-layer expansion of the channel-mixing MLP.
            A non-positive value (e.g. -1) omits the hidden layer. Defaults
            to 4.
        expansion_factor_token: Same as ``expansion_factor`` but for the
            token-mixing MLP. Defaults to 4.
        permute_per_blocks: When ``>= 1`` (and ``perm_block_id`` is negative)
            every block's ``Conjugation`` is built with ``J="PT"``; otherwise
            ``J="T"``. Defaults to 0.
        remove_ppfc: When equal to 1 the per-patch linear layer is omitted,
            so the blocks receive ``patch_size ** 2 * channels`` features per
            patch directly. Defaults to 0.
        force_fc_in_skip: Forwarded to ``PreNormResidualTokenDiff`` for the
            first block. Defaults to 0.
        perm_block_id: When ``>= 0``, only the block with this index uses
            ``J="PT"``; this takes precedence over ``permute_per_blocks``.
            Defaults to -1.

    Returns:
        nn.Sequential: patch stem, mixer blocks, LayerNorm, mean over tokens,
        and the classification linear layer.

    Raises:
        AssertionError: If either image side is not a multiple of
            ``patch_size``.
    """
    image_h, image_w = pair(image_size)
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_h // patch_size) * (image_w // patch_size)
    patch_features = (patch_size ** 2) * channels

    # Patch stem: split into patches, then (optionally) project each patch.
    stem = [
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
    ]
    if remove_ppfc != 1:
        stem.append(nn.Linear(patch_features, dim))

    # Indices of the blocks whose token mixer is built with J="PT".
    if perm_block_id >= 0:
        permuted = [perm_block_id]
    elif permute_per_blocks >= 1:
        permuted = range(depth)
    else:
        permuted = []

    layers = []
    for index in range(depth):
        mode = "PT" if index in permuted else "T"
        if index == 0:
            # The first token mixer changes the token count: num_patches -> dim_token.
            token_mixer = Conjugation(
                dim=dim,
                dim_token=num_patches,
                fn=FeedForward(num_patches, dim_token, expansion_factor_token),
                o_dim=dim,
                o_dim_token=dim_token,
                J=mode,
            )
            layers.append(
                PreNormResidualTokenDiff(
                    dim=dim,
                    fn=token_mixer,
                    i_dim_token=num_patches,
                    o_dim_token=dim_token,
                    force_fc_in_skip=force_fc_in_skip,
                )
            )
        else:
            token_mixer = Conjugation(
                dim,
                dim_token,
                FeedForward(dim_token, dim_token, expansion_factor_token),
                J=mode,
            )
            layers.append(PreNormResidual(dim, token_mixer))

        # Channel-mixing layer closes every block.
        layers.append(PreNormResidual(dim, FeedForward(dim, dim, expansion_factor)))

    head = [
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes),
    ]
    return nn.Sequential(*stem, *layers, *head)



def SMixer(*, image_size, channels, patch_size, dim, dim_token, depth, num_classes, permute_per_blocks=0, remove_ppfc=0,
           force_fc_in_skip=0, perm_block_id=-1):
    """Build an S-Mixer: an ``MLPMixer`` with both expansion factors set to -1.

    All arguments are keyword-only and are forwarded unchanged to
    ``MLPMixer``; ``expansion_factor`` and ``expansion_factor_token`` are
    fixed to -1.

    Returns:
        Whatever ``MLPMixer`` returns for these settings.
    """
    # MLPMixer declares every parameter keyword-only, so all arguments must be
    # passed by name; positional passing raises TypeError.
    return MLPMixer(
        image_size=image_size,
        channels=channels,
        patch_size=patch_size,
        dim=dim,
        dim_token=dim_token,
        depth=depth,
        num_classes=num_classes,
        expansion_factor=-1,
        expansion_factor_token=-1,
        permute_per_blocks=permute_per_blocks,
        remove_ppfc=remove_ppfc,
        force_fc_in_skip=force_fc_in_skip,
        perm_block_id=perm_block_id,
    )

def HiddenMLPMixer(dim, dim_token, depth, expansion_factor=-1, expansion_factor_token=-1, permute_per_blocks=0, num_classes=10):
    """Build the mixer blocks and classifier head for an already-embedded input.

    The input is a flat ``(b, dim_token * dim)`` tensor, which is first
    reshaped to ``(b, dim_token, dim)``.

    Args:
        dim: Channel width C.
        dim_token: Token width S.
        depth: Number of base blocks (token-mixing then channel-mixing).
        expansion_factor: Expansion of the channel-mixing MLP. Defaults to -1.
        expansion_factor_token: Expansion of the token-mixing MLP. Defaults
            to -1.
        permute_per_blocks: When ``>= 1`` every ``Conjugation`` is built with
            ``J="PT"``, otherwise with ``J="T"``. Defaults to 0.
        num_classes: Number of output classes. Defaults to 10.

    Returns:
        nn.Sequential: reshape, mixer blocks, LayerNorm, mean over tokens,
        and the classification linear layer.
    """
    mode = "PT" if permute_per_blocks >= 1 else "T"

    layers = [Rearrange("b (s c) -> b s c", s=dim_token, c=dim)]
    for _ in range(depth):
        token_mlp = FeedForward(dim_token, dim_token, expansion_factor_token)
        token_mixer = Conjugation(dim, dim_token, token_mlp, J=mode)
        layers.append(PreNormResidual(dim, token_mixer))

        channel_mlp = FeedForward(dim, dim, expansion_factor)
        layers.append(PreNormResidual(dim, channel_mlp))

    layers.extend([
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes),
    ])
    return nn.Sequential(*layers)


def MLPMixerSep(*, image_size, channels, patch_size, dim, dim_token, depth, num_classes, expansion_factor=4, expansion_factor_token=4, permute_per_blocks=0,
                dim_ppfc=None):
    """Build a mixer as two separate stages: ``PPFC`` stem + ``HiddenMLPMixer``.

    ``PPFC`` emits a flat axis of ``num_patches * dim_ppfc`` values and
    ``HiddenMLPMixer`` reshapes a flat axis of ``dim_token * dim`` values, so
    the two products have to be equal for the stages to fit together.

    Args:
        image_size: Input image side length, or a ``(height, width)`` tuple.
        channels: Number of input image channels.
        patch_size: Patch side length.
        dim: Channel width of the mixer blocks.
        dim_token: Token width of the mixer blocks.
        depth: Number of base blocks.
        num_classes: Number of output classes.
        expansion_factor: Expansion of the channel-mixing MLP. Defaults to 4.
        expansion_factor_token: Expansion of the token-mixing MLP. Defaults
            to 4.
        permute_per_blocks: Forwarded to ``HiddenMLPMixer``. Defaults to 0.
        dim_ppfc: Output width of the per-patch linear layer. When ``None``
            it is inferred as ``dim_token * dim // num_patches``, the only
            width that makes the two stages shape-compatible. An explicit
            value is passed through unchanged.

    Returns:
        nn.Sequential: ``PPFC`` followed by ``HiddenMLPMixer``.

    Raises:
        ValueError: If ``dim_ppfc`` is ``None`` and cannot be inferred because
            ``dim_token * dim`` is not a multiple of the number of patches.
    """
    if dim_ppfc is None:
        # None cannot be used as a Linear width, so derive the compatible one.
        image_h, image_w = pair(image_size)
        num_patches = (image_h // patch_size) * (image_w // patch_size)
        flat_features = dim_token * dim
        if num_patches > 0 and flat_features % num_patches == 0:
            dim_ppfc = flat_features // num_patches
        else:
            raise ValueError(
                "cannot infer dim_ppfc: dim_token * dim must be a multiple of the number of patches"
            )

    stem = PPFC(
        image_size=image_size,
        channels=channels,
        patch_size=patch_size,
        dim=dim_ppfc,
    )
    mixer = HiddenMLPMixer(
        dim=dim,
        dim_token=dim_token,
        depth=depth,
        num_classes=num_classes,
        permute_per_blocks=permute_per_blocks,
        expansion_factor=expansion_factor,
        expansion_factor_token=expansion_factor_token,
    )
    return nn.Sequential(stem, mixer)

        



